#' ---
#' title: Reproduce Figure 2 (Political Ad Tone)
#' date: 2024-10-01
#' version: 1.0
#' ---

library(tidyverse)
library(promptr)
library(patchwork)
# Fix the RNG seed so the randomised parts of the script are repeatable
set.seed(42)

# Where the model output comes from:
#   TRUE  -> reproduce figures and tables from the saved data files
#   FALSE -> send new prompts to OpenAI
from_file <- TRUE

## Load Data ----------------------------

d <- read_csv(file = 'application2-carlson-montgomery-2017.csv')

# Task description placed at the top of every one-shot prompt
instructions <- 'Decide whether the tone of a political advertisement is Positive, Neutral, or Negative.'

# Worked example for the one-shot prompt: the text of the ad(s) with
# ids == 7040, paired with the label 'Negative'
examples <- d |>
  filter(ids == 7040) |>
  select(text) |>
  mutate(label = 'Negative')

## Functions to compute tone from LLM output and plot results -------

# Turn a list of LLM completions into one tone score per completion.
#
# Arguments:
#   out  List of data frames, each with a `token` column and a `probability`
#        column (one row per candidate token).
#
# Returns a numeric vector with one value per element of `out`, equal to
# P(negative) - P(positive), so larger values mean a more negative tone.
#
# An alternative that was tried and left out: bind the per-completion
# probabilities into one table and use the first principal component
# (princomp) as the score.
get_score <- function(out){

  # Reduce a single completion table to a one-row table holding
  # score = P(positive) - P(negative).
  tone_gap <- function(completion){
    completion |>
      # normalise tokens so that e.g. ' Negative' and 'negative' are pooled
      mutate(token = str_to_lower(str_trim(token))) |>
      summarize(negative = sum(probability[token == 'negative']),
                neutral = sum(probability[token == 'neutral']),
                positive = sum(probability[token == 'positive'])) |>
      summarize(score = positive - negative)
  }

  gaps <- unlist(lapply(out, tone_gap))

  # flip the sign so that higher values correspond to more negative ads
  return(-1 * gaps)
}

# Scatter an automated tone score against the expert-coded tone.
#
# Arguments:
#   model  Character label for the scoring method; used in the y-axis label
#          and in the plot title.
#   data   Data frame with columns `tone` (x-axis) and `score` (y-axis).
#          Defaults to the script-level `d`, looked up when the function is
#          called, so existing calls such as plot_results('GPT-3') behave as
#          before. Passing the data explicitly avoids depending on whatever
#          `d$score` currently holds.
#
# Returns a ggplot object: jittered points, a linear fit, and the
# correlation between `score` and `tone` (cor() default method, computed on
# pairwise-complete observations) shown in the title.
plot_results <- function(model, data = d){

  # fail early with a clear message if a required column is missing
  stopifnot(all(c('tone', 'score') %in% names(data)))

  correlation <- cor(data$score,
                     data$tone,
                     use = 'pairwise.complete.obs')

  ggplot(data = data,
         mapping = aes(
           x = tone,
           y = score
         )) +
    geom_jitter(width = 0.15, alpha = 0.5) +
    labs(x = 'Expert Tone',
         y = paste0(model, ' Tone'),
         title = paste0(model, ' ',
                        '(\U03C1 = ',
                        round(correlation, 2),
                        ')')) +
    theme_bw() +
    # explicit formula: same fit as the default, without the console message
    geom_smooth(method = 'lm', formula = y ~ x, color = 'black')
}


## GPT-3 One-Shot --------------------------

if (from_file) {
  # Restore the saved completions; the code below expects this file to
  # define `out`.
  load(file = 'application2-one-shot-gpt-3.RData')
} else {
  # one formatted one-shot prompt per ad
  prompts <- lapply(d$text,
                    format_prompt,
                    instructions = instructions,
                    examples = examples)

  # Split up the requests to handle TPM rate limits: the first `batch_size`
  # prompts go out immediately, the remainder after a 60 second pause.
  # The batches are derived from length(prompts) rather than hard-coded row
  # numbers, so no ad is dropped or padded if the data set changes size.
  batch_size <- 500
  out <- complete_prompt(prompt = head(prompts, batch_size),
                         model = 'davinci-002')

  remaining <- tail(prompts, -batch_size)
  if (length(remaining) > 0) {
    Sys.sleep(60)
    out <- c(out,
             complete_prompt(prompt = remaining,
                             model = 'davinci-002'))
  }
}

# continuous tone score (higher = more negative) for each ad
d$score <- get_score(out)

# most probable token of each completion, kept as returned by the model
# NOTE(review): unlike get_score(), tokens are not trimmed or lower-cased
# here - confirm that downstream uses of gpt_3_label expect raw tokens.
d$gpt_3_label <- out |>
  lapply(\(completion) {
    completion |>
      slice_max(probability, with_ties = FALSE) |>
      pull(token)
  }) |>
  unlist()

p2 <- plot_results('GPT-3')

p2

## GPT-4 One-Shot ------------------------

if (from_file) {
  # Restore the saved completions; the code below expects this file to
  # define `out`.
  load(file = 'application2-one-shot-gpt-4.RData')
} else {
  # one formatted one-shot chat prompt per ad
  prompts <- lapply(d$text,
                    format_chat,
                    instructions = instructions,
                    examples = examples)

  out <- complete_chat(prompts,
                       model = 'gpt-4-turbo-preview',
                       parallel = TRUE)
}

# continuous tone score (higher = more negative) for each ad
d$score <- get_score(out)

# most probable token of each completion, kept as returned by the model
d$gpt_4_label <- out |>
  lapply(\(completion) {
    completion |>
      slice_max(probability, with_ties = FALSE) |>
      pull(token)
  }) |>
  unlist()

p3 <- plot_results('GPT-4')

p3

## Compare to crowd-coders and build Figure 2 ------------

# Use the `alphas` column as the score for the crowd-coded panel
d$score <- d$alphas
p1 <- plot_results('Crowd-Coded')
p1

# Compose the 2 x 2 layout (bottom-right cell left empty) and keep it in a
# variable so that exactly this object is written to disk. Relying on
# ggsave()'s default of last_plot() only saves the full figure if it happened
# to be the most recently registered plot.
figure2 <- (p1 + p2) / (p3 + patchwork::plot_spacer())
figure2

ggsave(filename = 'figure2.png',
       plot = figure2,
       width = 8, height = 8)
